// Copyright (C) Kumo inc. and its affiliates.
// Author: Jeff.li lijippy@163.com
// All rights reserved.
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program.  If not, see <https://www.gnu.org/licenses/>.
//

#pragma once

#include <nebula/version.h>

#if NEBULA_HAVE_RUNTIME_AVX2_SUPPORTED

#include <immintrin.h>

#endif

#include <cstdint>

#include <nebula/compute/light_array_internal.h>
#include <nebula/compute/util.h>

namespace nebula::compute {

    // Forward declaration needed only so that test functions can be declared as
    // friends of the classes in this file.
    //
    enum class BloomFilterBuildStrategy;

    // Implementations are based on the 32-bit xxHash (XXH32) algorithm described in:
    // https://github.com/Cyan4973/xxHash/blob/dev/doc/xxhash_spec.md
    // (the PRIME32_* values below are the XXH32 prime constants).
    //
    class TURBO_EXPORT Hashing32 {
        // Test-only friends that exercise the private hashing primitives directly.
        friend class TestVectorHash;

        template<typename T>
        friend void TestBloomLargeHashHelper(int64_t, int64_t, const std::vector<uint64_t> &,
                                             int64_t, int, T *);

        friend void TestBloomSmall(BloomFilterBuildStrategy, int64_t, int, bool, bool);

    public:
        // Hashes every row of all columns in `cols`, writing one 32-bit hash per row to
        // `out_hash`. Per-row hashes of the individual columns are combined together.
        static void HashMultiColumn(const std::vector<KeyColumnArray> &cols, LightContext *ctx,
                                    uint32_t *out_hash);

        // Clarify the max temp stack usage for HashBatch, which might be necessary for the
        // caller to be aware of at compile time to reserve enough stack size in advance. The
        // HashBatch implementation uses one uint32 temp vector as a buffer for hash, one uint16
        // temp vector as a buffer for null indices and one uint32 temp vector as a buffer for
        // null hash, all are of size kMiniBatchLength. Plus extra kMiniBatchLength to cope with
        // stack padding and aligning.
        static constexpr auto kHashBatchTempStackUsage =
                (sizeof(uint32_t) + sizeof(uint16_t) + sizeof(uint32_t) + /*extra=*/1) *
                util::MiniBatch::kMiniBatchLength;

        // Hashes rows [start_row, start_row + num_rows) of `key_batch` into `hashes`.
        // `temp_stack` must have at least kHashBatchTempStackUsage bytes available.
        static turbo::Status HashBatch(const ExecBatch &key_batch, uint32_t *hashes,
                                       std::vector<KeyColumnArray> &column_arrays,
                                       int64_t hardware_flags, util::TempVectorStack *temp_stack,
                                       int64_t start_row, int64_t num_rows);

        // Hashes `num_keys` fixed-length keys of `key_length` bytes each. When
        // `combine_hashes` is true the new values are folded into the existing contents
        // of `hashes`, using `temp_hashes_for_combine` as scratch space.
        static void HashFixed(int64_t hardware_flags, bool combine_hashes, uint32_t num_keys,
                              uint64_t key_length, const uint8_t *keys, uint32_t *hashes,
                              uint32_t *temp_hashes_for_combine);

    private:
        // XXH32 prime constants from the specification linked above.
        static constexpr uint32_t PRIME32_1 = 0x9E3779B1;
        static constexpr uint32_t PRIME32_2 = 0x85EBCA77;
        static constexpr uint32_t PRIME32_3 = 0xC2B2AE3D;
        static constexpr uint32_t PRIME32_4 = 0x27D4EB2F;
        static constexpr uint32_t PRIME32_5 = 0x165667B1;
        // Constant used by CombineHashesImp when folding one hash into another.
        static constexpr uint32_t kCombineConst = 0x9e3779b9UL;
        // Bytes consumed per full stripe: 4 accumulators x 4 bytes each.
        static constexpr int64_t kStripeSize = 4 * sizeof(uint32_t);

        // Hashes `num_rows` variable-length keys laid out back-to-back in
        // `concatenated_keys`; key boundaries are given by `offsets`.
        static void HashVarLen(int64_t hardware_flags, bool combine_hashes, uint32_t num_rows,
                               const uint32_t *offsets, const uint8_t *concatenated_keys,
                               uint32_t *hashes, uint32_t *temp_hashes_for_combine);

        // Overload taking 64-bit offsets.
        static void HashVarLen(int64_t hardware_flags, bool combine_hashes, uint32_t num_rows,
                               const uint64_t *offsets, const uint8_t *concatenated_keys,
                               uint32_t *hashes, uint32_t *temp_hashes_for_combine);

        // Final mixing ("avalanche") step of XXH32: spreads the influence of every
        // input bit across the whole 32-bit result.
        static inline uint32_t Avalanche(uint32_t acc) {
            acc ^= (acc >> 15);
            acc *= PRIME32_2;
            acc ^= (acc >> 13);
            acc *= PRIME32_3;
            acc ^= (acc >> 16);
            return acc;
        }

        // Single XXH32 accumulator round for one 4-byte lane of input.
        static inline uint32_t Round(uint32_t acc, uint32_t input);

        // Merges the four stripe accumulators into a single 32-bit value.
        static inline uint32_t CombineAccumulators(uint32_t acc1, uint32_t acc2, uint32_t acc3,
                                                   uint32_t acc4);

        // Folds `hash` into `previous_hash` (same formula as the widely used
        // boost::hash_combine scheme).
        static inline uint32_t CombineHashesImp(uint32_t previous_hash, uint32_t hash) {
            uint32_t next_hash = previous_hash ^ (hash + kCombineConst + (previous_hash << 6) +
                                                  (previous_hash >> 2));
            return next_hash;
        }

        // Runs `num_stripes` full kStripeSize-byte stripes of `key` through the four
        // accumulators.
        static inline void ProcessFullStripes(uint64_t num_stripes, const uint8_t *key,
                                              uint32_t *out_acc1, uint32_t *out_acc2,
                                              uint32_t *out_acc3, uint32_t *out_acc4);

        // Processes the final (possibly partial) stripe using the per-lane masks
        // produced by StripeMask.
        static inline void ProcessLastStripe(uint32_t mask1, uint32_t mask2, uint32_t mask3,
                                             uint32_t mask4, const uint8_t *last_stripe,
                                             uint32_t *acc1, uint32_t *acc2, uint32_t *acc3,
                                             uint32_t *acc4);

        // Computes the four lane masks for a last stripe containing `i` valid bytes.
        static inline void StripeMask(int i, uint32_t *mask1, uint32_t *mask2, uint32_t *mask3,
                                      uint32_t *mask4);

        // Scalar implementation for fixed-length keys; T_COMBINE_HASHES selects whether
        // results replace or are combined with the existing contents of `hashes`.
        template<bool T_COMBINE_HASHES>
        static void HashFixedLenImp(uint32_t num_rows, uint64_t key_length, const uint8_t *keys,
                                    uint32_t *hashes);

        // Scalar implementation for variable-length keys; T is the offset type.
        template<typename T, bool T_COMBINE_HASHES>
        static void HashVarLenImp(uint32_t num_rows, const T *offsets,
                                  const uint8_t *concatenated_keys, uint32_t *hashes);

        // Implementation for bit-packed (1-bit) keys starting at `bit_offset`.
        template<bool T_COMBINE_HASHES>
        static void HashBitImp(int64_t bit_offset, uint32_t num_keys, const uint8_t *keys,
                               uint32_t *hashes);

        static void HashBit(bool combine_hashes, int64_t bit_offset, uint32_t num_keys,
                            const uint8_t *keys, uint32_t *hashes);

        // Implementation specialized for integer keys of type T.
        template<bool T_COMBINE_HASHES, typename T>
        static void HashIntImp(uint32_t num_keys, const T *keys, uint32_t *hashes);

        static void HashInt(bool combine_hashes, uint32_t num_keys, uint64_t key_length,
                            const uint8_t *keys, uint32_t *hashes);

#if NEBULA_HAVE_RUNTIME_AVX2_SUPPORTED

        // AVX2 variants of the scalar primitives above.
        // NOTE(review): the uint32_t return values presumably report how many rows the
        // vectorized path processed so the caller can finish the remainder with the
        // scalar path — confirm against the implementation file.

        static inline __m256i Avalanche_avx2(__m256i hash);

        static inline __m256i CombineHashesImp_avx2(__m256i previous_hash, __m256i hash);

        template<bool T_COMBINE_HASHES>
        static void AvalancheAll_avx2(uint32_t num_rows, uint32_t *hashes,
                                      const uint32_t *hashes_temp_for_combine);

        static inline __m256i Round_avx2(__m256i acc, __m256i input);

        static inline uint64_t CombineAccumulators_avx2(__m256i acc);

        static inline __m256i StripeMask_avx2(int i, int j);

        template<bool two_equal_lengths>
        static inline __m256i ProcessStripes_avx2(int64_t num_stripes_A, int64_t num_stripes_B,
                                                  __m256i mask_last_stripe, const uint8_t *keys,
                                                  int64_t offset_A, int64_t offset_B);

        template<bool T_COMBINE_HASHES>
        static uint32_t HashFixedLenImp_avx2(uint32_t num_rows, uint64_t key_length,
                                             const uint8_t *keys, uint32_t *hashes,
                                             uint32_t *hashes_temp_for_combine);

        static uint32_t HashFixedLen_avx2(bool combine_hashes, uint32_t num_rows,
                                          uint64_t key_length, const uint8_t *keys,
                                          uint32_t *hashes, uint32_t *hashes_temp_for_combine);

        template<typename T, bool T_COMBINE_HASHES>
        static uint32_t HashVarLenImp_avx2(uint32_t num_rows, const T *offsets,
                                           const uint8_t *concatenated_keys, uint32_t *hashes,
                                           uint32_t *hashes_temp_for_combine);

        static uint32_t HashVarLen_avx2(bool combine_hashes, uint32_t num_rows,
                                        const uint32_t *offsets,
                                        const uint8_t *concatenated_keys, uint32_t *hashes,
                                        uint32_t *hashes_temp_for_combine);

        static uint32_t HashVarLen_avx2(bool combine_hashes, uint32_t num_rows,
                                        const uint64_t *offsets,
                                        const uint8_t *concatenated_keys, uint32_t *hashes,
                                        uint32_t *hashes_temp_for_combine);

#endif
    };

    // 64-bit counterpart of Hashing32. The PRIME64_* values below are the prime
    // constants of the 64-bit xxHash (XXH64) algorithm; see:
    // https://github.com/Cyan4973/xxHash/blob/dev/doc/xxhash_spec.md
    //
    class TURBO_EXPORT Hashing64 {
        // Test-only friends that exercise the private hashing primitives directly.
        friend class TestVectorHash;

        template<typename T>
        friend void TestBloomLargeHashHelper(int64_t, int64_t, const std::vector<uint64_t> &,
                                             int64_t, int, T *);

        friend void TestBloomSmall(BloomFilterBuildStrategy, int64_t, int, bool, bool);

    public:
        // Hashes every row of all columns in `cols`, writing one 64-bit hash per row to
        // `hashes`. Per-row hashes of the individual columns are combined together.
        static void HashMultiColumn(const std::vector<KeyColumnArray> &cols, LightContext *ctx,
                                    uint64_t *hashes);

        // Clarify the max temp stack usage for HashBatch, which might be necessary for the
        // caller to be aware of at compile time to reserve enough stack size in advance. The
        // HashBatch implementation uses one uint16 temp vector as a buffer for null indices and
        // one uint64 temp vector as a buffer for null hash, all are of size kMiniBatchLength.
        // Plus extra kMiniBatchLength to cope with stack padding and aligning.
        static constexpr auto kHashBatchTempStackUsage =
                (sizeof(uint16_t) + sizeof(uint64_t) + /*extra=*/1) *
                util::MiniBatch::kMiniBatchLength;

        // Hashes rows [start_row, start_row + num_rows) of `key_batch` into `hashes`.
        // `temp_stack` must have at least kHashBatchTempStackUsage bytes available.
        static turbo::Status HashBatch(const ExecBatch &key_batch, uint64_t *hashes,
                                       std::vector<KeyColumnArray> &column_arrays,
                                       int64_t hardware_flags, util::TempVectorStack *temp_stack,
                                       int64_t start_row, int64_t num_rows);

        // Hashes `num_keys` fixed-length keys of `key_length` bytes each. When
        // `combine_hashes` is true the new values are folded into the existing
        // contents of `hashes`.
        static void HashFixed(bool combine_hashes, uint32_t num_keys, uint64_t key_length,
                              const uint8_t *keys, uint64_t *hashes);

    private:
        // XXH64 prime constants from the specification linked above.
        static constexpr uint64_t PRIME64_1 = 0x9E3779B185EBCA87ULL;
        static constexpr uint64_t PRIME64_2 = 0xC2B2AE3D27D4EB4FULL;
        static constexpr uint64_t PRIME64_3 = 0x165667B19E3779F9ULL;
        static constexpr uint64_t PRIME64_4 = 0x85EBCA77C2B2AE63ULL;
        static constexpr uint64_t PRIME64_5 = 0x27D4EB2F165667C5ULL;
        // Constant used by CombineHashesImp when folding one hash into another.
        // NOTE(review): this is a 32-bit constant used in the 64-bit combine below,
        // mirroring Hashing32::kCombineConst — confirm this is intentional.
        static constexpr uint32_t kCombineConst = 0x9e3779b9UL;
        // Bytes consumed per full stripe: 4 accumulators x 8 bytes each.
        static constexpr int64_t kStripeSize = 4 * sizeof(uint64_t);

        // Hashes `num_rows` variable-length keys laid out back-to-back in
        // `concatenated_keys`; key boundaries are given by `offsets`.
        static void HashVarLen(bool combine_hashes, uint32_t num_rows, const uint32_t *offsets,
                               const uint8_t *concatenated_keys, uint64_t *hashes);

        // Overload taking 64-bit offsets.
        static void HashVarLen(bool combine_hashes, uint32_t num_rows, const uint64_t *offsets,
                               const uint8_t *concatenated_keys, uint64_t *hashes);

        // Final mixing ("avalanche") step applied to the accumulated hash.
        static inline uint64_t Avalanche(uint64_t acc);

        // Single accumulator round for one 8-byte lane of input.
        static inline uint64_t Round(uint64_t acc, uint64_t input);

        // Merges the four stripe accumulators into a single 64-bit value.
        static inline uint64_t CombineAccumulators(uint64_t acc1, uint64_t acc2, uint64_t acc3,
                                                   uint64_t acc4);

        // Folds `hash` into `previous_hash` (same formula as the widely used
        // boost::hash_combine scheme).
        static inline uint64_t CombineHashesImp(uint64_t previous_hash, uint64_t hash) {
            uint64_t next_hash = previous_hash ^ (hash + kCombineConst + (previous_hash << 6) +
                                                  (previous_hash >> 2));
            return next_hash;
        }

        // Runs `num_stripes` full kStripeSize-byte stripes of `key` through the four
        // accumulators.
        static inline void ProcessFullStripes(uint64_t num_stripes, const uint8_t *key,
                                              uint64_t *out_acc1, uint64_t *out_acc2,
                                              uint64_t *out_acc3, uint64_t *out_acc4);

        // Processes the final (possibly partial) stripe using the per-lane masks
        // produced by StripeMask.
        static inline void ProcessLastStripe(uint64_t mask1, uint64_t mask2, uint64_t mask3,
                                             uint64_t mask4, const uint8_t *last_stripe,
                                             uint64_t *acc1, uint64_t *acc2, uint64_t *acc3,
                                             uint64_t *acc4);

        // Computes the four lane masks for a last stripe containing `i` valid bytes.
        static inline void StripeMask(int i, uint64_t *mask1, uint64_t *mask2, uint64_t *mask3,
                                      uint64_t *mask4);

        // Scalar implementation for fixed-length keys; T_COMBINE_HASHES selects whether
        // results replace or are combined with the existing contents of `hashes`.
        template<bool T_COMBINE_HASHES>
        static void HashFixedLenImp(uint32_t num_rows, uint64_t key_length, const uint8_t *keys,
                                    uint64_t *hashes);

        // Scalar implementation for variable-length keys; T is the offset type.
        template<typename T, bool T_COMBINE_HASHES>
        static void HashVarLenImp(uint32_t num_rows, const T *offsets,
                                  const uint8_t *concatenated_keys, uint64_t *hashes);

        // Implementation for bit-packed (1-bit) keys starting at `bit_offset`.
        template<bool T_COMBINE_HASHES>
        static void HashBitImp(int64_t bit_offset, uint32_t num_keys, const uint8_t *keys,
                               uint64_t *hashes);

        static void HashBit(bool combine_hashes, int64_t bit_offset, uint32_t num_keys,
                            const uint8_t *keys, uint64_t *hashes);

        // Implementation specialized for integer keys of type T.
        template<bool T_COMBINE_HASHES, typename T>
        static void HashIntImp(uint32_t num_keys, const T *keys, uint64_t *hashes);

        static void HashInt(bool combine_hashes, uint32_t num_keys, uint64_t key_length,
                            const uint8_t *keys, uint64_t *hashes);
    };

}  // namespace nebula::compute
